{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reducing Multi-Turn Confusion with LlamaIndex Memory\n",
    "\n",
    "[Recent research](https://arxiv.org/abs/2505.06120) has shown the performance of an LLM significantly degrades given multi-turn conversations.\n",
    "\n",
    "To help avoid this, we can implement a custom short-term and long-term memory in LlamaIndex to ensure that the conversation turns never get too long, and condense the memory as we go.\n",
    "\n",
    "Using the code from this notebook, you may see improvements in your own agents as it works to limit how many turns are in your chat history.\n",
    "\n",
    "**NOTE:** This notebook was tested with `llama-index-core>=0.12.37`, as that version included some fixes to make this work nicely."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U llama-index-core llama-index-llms-openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "To make this work, we need two things\n",
    "1. A memory block that condenses all past chat messages into a single string while maintaining a token limit\n",
    "2. A `Memory` instance that uses that memory block, and has token limits configured such that multi-turn conversations are always flushed to the memory block for handling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, the custom memory block:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tiktoken\n",
    "from pydantic import Field\n",
    "from typing import List, Optional, Any\n",
    "from llama_index.core.llms import ChatMessage, TextBlock\n",
    "from llama_index.core.memory import Memory, BaseMemoryBlock\n",
    "\n",
    "\n",
    "class CondensedMemoryBlock(BaseMemoryBlock[str]):\n",
    "    current_memory: List[str] = Field(default_factory=list)\n",
    "    token_limit: int = Field(default=50000)\n",
    "    tokenizer: tiktoken.Encoding = tiktoken.encoding_for_model(\n",
    "        \"gpt-4o\"\n",
    "    )  # all openai models use 4o tokenizer these days\n",
    "\n",
    "    async def _aget(\n",
    "        self, messages: Optional[List[ChatMessage]] = None, **block_kwargs: Any\n",
    "    ) -> str:\n",
    "        \"\"\"Return the current memory block contents.\"\"\"\n",
    "        return \"\\n\".join(self.current_memory)\n",
    "\n",
    "    async def _aput(self, messages: List[ChatMessage]) -> None:\n",
    "        \"\"\"Push messages into the memory block. (Only handles text content)\"\"\"\n",
    "        # construct a string for each message\n",
    "        for message in messages:\n",
    "            text_contents = \"\\n\".join(\n",
    "                block.text\n",
    "                for block in message.blocks\n",
    "                if isinstance(block, TextBlock)\n",
    "            )\n",
    "            memory_str = f\"<message role={message.role}>\"\n",
    "\n",
    "            if text_contents:\n",
    "                memory_str += f\"\\n{text_contents}\"\n",
    "\n",
    "            # include additional kwargs, like tool calls, when needed\n",
    "            # filter out injected session_id\n",
    "            kwargs = {\n",
    "                key: val\n",
    "                for key, val in message.additional_kwargs.items()\n",
    "                if key != \"session_id\"\n",
    "            }\n",
    "            if kwargs:\n",
    "                memory_str += f\"\\n({kwargs})\"\n",
    "\n",
    "            memory_str += \"\\n</message>\"\n",
    "            self.current_memory.append(memory_str)\n",
    "\n",
    "        # ensure this memory block doesn't get too large\n",
    "        message_length = sum(\n",
    "            len(self.tokenizer.encode(message))\n",
    "            for message in self.current_memory\n",
    "        )\n",
    "        while message_length > self.token_limit:\n",
    "            self.current_memory = self.current_memory[1:]\n",
    "            message_length = sum(\n",
    "                len(self.tokenizer.encode(message))\n",
    "                for message in self.current_memory\n",
    "            )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And then, a `Memory` instance that uses that block while configuring a very limited token limit for the short-term memory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "block = CondensedMemoryBlock(name=\"condensed_memory\")\n",
    "\n",
    "memory = Memory.from_defaults(\n",
    "    session_id=\"test-mem-01\",\n",
    "    token_limit=60000,\n",
    "    token_flush_size=5000,\n",
    "    async_database_uri=\"sqlite+aiosqlite:///:memory:\",\n",
    "    memory_blocks=[block],\n",
    "    insert_method=\"user\",\n",
    "    # Prevent the short-term chat history from containing too many turns!\n",
    "    # This limit will effectively mean that the short-term memory is always flushed\n",
    "    chat_history_token_ratio=0.0001,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage\n",
    "\n",
    "Let's explore using this with some dummy messages, and observe how the memory is managed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_messages = [\n",
    "    ChatMessage(role=\"user\", content=\"Hello! My name is Logan\"),\n",
    "    ChatMessage(role=\"assistant\", content=\"Hello! How can I help you?\"),\n",
    "    ChatMessage(role=\"user\", content=\"What is the capital of France?\"),\n",
    "    ChatMessage(role=\"assistant\", content=\"The capital of France is Paris\"),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await memory.aput_messages(initial_messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, lets add our next user message!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await memory.aput_messages(\n",
    "    [ChatMessage(role=\"user\", content=\"What was my name again?\")]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With that, we can explore what the chat history looks like before sending to an LLM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MessageRole.USER\n",
      "<memory>\n",
      "<condensed_memory>\n",
      "<message role=MessageRole.USER>\n",
      "Hello! My name is Logan\n",
      "</message>\n",
      "<message role=MessageRole.ASSISTANT>\n",
      "Hello! How can I help you?\n",
      "</message>\n",
      "<message role=MessageRole.USER>\n",
      "What is the capital of France?\n",
      "</message>\n",
      "<message role=MessageRole.ASSISTANT>\n",
      "The capital of France is Paris\n",
      "</message>\n",
      "</condensed_memory>\n",
      "</memory>\n",
      "What was my name again?\n",
      "\n"
     ]
    }
   ],
   "source": [
    "chat_history = await memory.aget()\n",
    "\n",
    "for message in chat_history:\n",
    "    print(message.role)\n",
    "    print(message.content)\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! Even though we added many messages, it gets condensed into a single user message!\n",
    "\n",
    "Let's try with an actual agent next."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agent Usage\n",
    "\n",
    "Here, we can create a `FunctionAgent` with some simple tools that uses our memory.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.agent.workflow import FunctionAgent\n",
    "from llama_index.llms.openai import OpenAI\n",
    "\n",
    "\n",
    "def multiply(a: float, b: float) -> float:\n",
    "    \"\"\"Multiply two numbers.\"\"\"\n",
    "    return a * b\n",
    "\n",
    "\n",
    "def divide(a: float, b: float) -> float:\n",
    "    \"\"\"Divide two numbers.\"\"\"\n",
    "    return a / b\n",
    "\n",
    "\n",
    "def add(a: float, b: float) -> float:\n",
    "    \"\"\"Add two numbers.\"\"\"\n",
    "    return a + b\n",
    "\n",
    "\n",
    "def subtract(a: float, b: float) -> float:\n",
    "    \"\"\"Subtract two numbers.\"\"\"\n",
    "    return a - b\n",
    "\n",
    "\n",
    "llm = OpenAI(model=\"gpt-4.1-mini\")\n",
    "\n",
    "agent = FunctionAgent(\n",
    "    tools=[multiply, divide, add, subtract],\n",
    "    llm=llm,\n",
    "    system_prompt=\"You are a helpful assistant that can do simple math operations with tools.\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "block = CondensedMemoryBlock(name=\"condensed_memory\")\n",
    "\n",
    "memory = Memory.from_defaults(\n",
    "    session_id=\"test-mem-01\",\n",
    "    token_limit=60000,\n",
    "    token_flush_size=5000,\n",
    "    async_database_uri=\"sqlite+aiosqlite:///:memory:\",\n",
    "    memory_blocks=[block],\n",
    "    insert_method=\"user\",\n",
    "    # Prevent the short-term chat history from containing too many turns!\n",
    "    # This limit will effectively mean that the short-term memory is always flushed\n",
    "    chat_history_token_ratio=0.0001,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The value of (3214 * 322) / 2 is 517454.0.\n"
     ]
    }
   ],
   "source": [
    "resp = await agent.run(\"What is (3214 * 322) / 2?\", memory=memory)\n",
    "print(resp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MessageRole.ASSISTANT\n",
      "The value of (3214 * 322) / 2 is 517454.0.\n",
      "\n",
      "MessageRole.USER\n",
      "<memory>\n",
      "<condensed_memory>\n",
      "<message role=MessageRole.USER>\n",
      "What is (3214 * 322) / 2?\n",
      "</message>\n",
      "<message role=MessageRole.ASSISTANT>\n",
      "({'tool_calls': [{'index': 0, 'id': 'call_U78I0CSWETFQlRBCWPpswEmq', 'function': {'arguments': '{\"a\": 3214, \"b\": 322}', 'name': 'multiply'}, 'type': 'function'}, {'index': 1, 'id': 'call_3eFXqalMN9PyiCVEYE073bEl', 'function': {'arguments': '{\"a\": 3214, \"b\": 2}', 'name': 'divide'}, 'type': 'function'}]})\n",
      "</message>\n",
      "<message role=MessageRole.TOOL>\n",
      "1034908\n",
      "({'tool_call_id': 'call_U78I0CSWETFQlRBCWPpswEmq'})\n",
      "</message>\n",
      "<message role=MessageRole.TOOL>\n",
      "1607.0\n",
      "({'tool_call_id': 'call_3eFXqalMN9PyiCVEYE073bEl'})\n",
      "</message>\n",
      "<message role=MessageRole.ASSISTANT>\n",
      "({'tool_calls': [{'index': 0, 'id': 'call_GvtLKm7FCzlaucfYnaxOLBVW', 'function': {'arguments': '{\"a\":1034908,\"b\":2}', 'name': 'divide'}, 'type': 'function'}]})\n",
      "</message>\n",
      "<message role=MessageRole.TOOL>\n",
      "517454.0\n",
      "({'tool_call_id': 'call_GvtLKm7FCzlaucfYnaxOLBVW'})\n",
      "</message>\n",
      "</condensed_memory>\n",
      "</memory>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "current_chat_history = await memory.aget()\n",
    "for message in current_chat_history:\n",
    "    print(message.role)\n",
    "    print(message.content)\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perfect! Since the memory didn't have a new user message yet, it added one with our current memory. On the next user message, that memory and the user message would get combined like we saw earlier.\n",
    "\n",
    "Let's try a few follow ups to confirm this is working properly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The last question you asked was: \"What is (3214 * 322) / 2?\"\n"
     ]
    }
   ],
   "source": [
    "resp = await agent.run(\n",
    "    \"What was the last question I asked you?\", memory=memory\n",
    ")\n",
    "print(resp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "To answer your question \"What is (3214 * 322) / 2?\", I followed these steps:\n",
      "\n",
      "1. First, I multiplied 3214 by 322.\n",
      "2. Then, I divided the result of that multiplication by 2.\n",
      "3. Finally, I provided you with the result of the calculation, which is 517454.0.\n"
     ]
    }
   ],
   "source": [
    "resp = await agent.run(\n",
    "    \"And how did you go about answering that message?\", memory=memory\n",
    ")\n",
    "print(resp)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
